#!/usr/bin/env python
# -*- coding: UTF-8 -*-

"""
Combine vector representations generated by RNN/GRU/LSTM networks
"""
from argparse import ArgumentParser
import glob
import numpy as np
import os
import pandas as pd
import shutil
import sys
import time

import torch

# Wall-clock reference captured once, when this module is first imported
# (i.e. after the heavy third-party imports above have completed).
start_time = time.time()

# Project-local helpers (relative imports: this module must be loaded as part
# of its package).
from .utils import read_input_file
from .utils import sort_key
from .utils import read_command_combinevecs

# --- set seed for reproducibility
from .utils import set_seed_everywhere

# NOTE: module-level side effect -- this call runs on import, not only when
# the module is executed as a script. The seed value 1364 is fixed here.
set_seed_everywhere(1364)

# ------------------- combine_vecs --------------------
def combine_vecs(
    input_file_path="default",
    rnn_passes=["fwd", "bwd"],
    input_scenario="test",
    output_scenario="test",
    print_every=500,
    sel_device="default",
    save_df=True,
    verbose=True,
):
    """
    Assemble vectors stored in input_scenario and save them in
    output_scenario.

    For every requested pass, all files matching
    ``<input_scenario>/embeddings/rnn_<pass>*`` are loaded, concatenated
    along the first dimension and written to ``<output_scenario>/<pass>.pt``.
    The files matching ``<input_scenario>/embeddings/rnn_indxs*`` are
    combined the same way and written to ``<output_scenario>/<pass>_id.pt``
    (the same IDs are written once per pass). If ``save_df`` is True, the
    ``s1_unicode`` and ``s1`` columns of ``<input_scenario>/dataframe.df``
    are saved to ``<output_scenario>/<pass>_items.npy``.

    Parameters
    ----------
    rnn_passes
        RNN/GRU/LSTM passes to be used in assembling vectors (fwd or bwd).
        Either a list of names or a single comma-separated string.
    input_scenario
        name of the input top-directory
    output_scenario
        name of the output top-directory (created if it does not exist)
    input_file_path
        path to the input file. "default": use the first ``*.yaml`` file
        found in `input_scenario`. Only consulted when `sel_device` is
        "default".
    print_every
        interval to print the progress in assembling vectors. A falsy value
        (0 or None) disables the progress lines.
    sel_device
        set the device (cpu, cuda, cuda:0, cuda:1, ...).
        if "default", the device will be read from the input file, and the
        input file is copied into `output_scenario`.
    save_df
        save strings of the first column in queries/candidates files (default: True)
    verbose
        verbose if True (default)

    Returns
    -------
    None
        All results are written to disk.

    Raises
    ------
    SystemExit
        If `input_scenario` is not a directory, if ``dataframe.df`` is
        missing from it, or if no ``*.yaml`` input file can be found when
        one is needed.
    """
    # Time this call only; a module-level timestamp would also count whatever
    # happened between import and the call when used as a library.
    call_start = time.time()

    # The list default is never mutated in place: it is rebound right here.
    if isinstance(rnn_passes, str):
        rnn_passes = rnn_passes.split(",")
    rnn_passes = [x.strip() for x in rnn_passes]

    if not rnn_passes:
        # Nothing requested: stay a no-op (no validation, no directories
        # created), exactly as an empty loop over the passes would be.
        if verbose:
            print("--- %s seconds ---" % (time.time() - call_start))
        return

    # ---- pass-independent validation and setup (done once, not per pass)
    if not os.path.isdir(input_scenario):
        sys.exit(f"Directory does not exist: {input_scenario}")

    os.makedirs(output_scenario, exist_ok=True)

    path2ids = os.path.join(input_scenario, "embeddings", "rnn_indxs*")
    pathdf = os.path.join(input_scenario, "dataframe.df")
    if not os.path.isfile(pathdf):
        sys.exit(f"File does not exist: {pathdf}")

    if sel_device == "default":
        if input_file_path == "default":
            # Use the first regular *.yaml file found in the input directory.
            yaml_files = (
                candidate
                for candidate in glob.iglob(os.path.join(input_scenario, "*.yaml"))
                if os.path.isfile(candidate)
            )
            input_file_path = next(yaml_files, None)
            if input_file_path is None:
                sys.exit(
                    f"[ERROR] no input file (*.yaml file) could be found in the dir: {input_scenario}"
                )

        # Keep a copy of the configuration next to the combined outputs.
        shutil.copy2(input_file_path, output_scenario)
        dl_inputs = read_input_file(input_file_path)
        sel_device = dl_inputs["general"]["device"]

    # Loaded lazily (after the first pass has been written) and reused for
    # every following pass instead of unpickling the dataframe each time.
    vecs_items = None

    for rnn_pass in rnn_passes:
        path2vecs = os.path.join(input_scenario, "embeddings", "rnn_" + rnn_pass + "*")
        path_vec_combined = os.path.join(output_scenario, f"{rnn_pass}.pt")
        path_id_combined = os.path.join(output_scenario, f"{rnn_pass}_id.pt")
        path_items_combined = os.path.join(output_scenario, f"{rnn_pass}_items.npy")

        if verbose:
            print("\n\n-- Combine vectors")
            print(f"Reading vectors from {path2vecs}")
        # Save combined vectors; the temporary is released right after saving.
        torch.save(
            _load_and_concat(path2vecs, sel_device, print_every, verbose),
            path_vec_combined,
        )

        if verbose:
            print("\n\n-- Combine IDs\n")
        # Save combined IDs
        torch.save(
            _load_and_concat(path2ids, sel_device, print_every, verbose),
            path_id_combined,
        )

        if save_df:
            # Save strings of the first column in queries/candidates files
            if vecs_items is None:
                # NOTE: read_pickle unpickles arbitrary objects; only use
                # scenario directories from a trusted source.
                mydf = pd.read_pickle(pathdf)
                vecs_items = mydf[["s1_unicode", "s1"]].to_numpy()
            np.save(path_items_combined, vecs_items)

    if verbose:
        print("--- %s seconds ---" % (time.time() - call_start))


def _load_and_concat(pattern, device, print_every, verbose):
    """
    Load every file matching the glob `pattern` with ``torch.load`` (mapped
    to `device`) and return their concatenation along the first dimension.

    Files are processed in the order defined by ``sort_key``. Returns an
    empty list when no file matches, so that callers keep writing the same
    placeholder as before; a warning is printed to stderr in that case.
    """
    list_files = glob.glob(pattern)
    list_files.sort(key=sort_key)

    chunks = []
    n_rows = 0
    for i, lfile in enumerate(list_files):
        # `print_every` guard: a value of 0/None disables progress output
        # instead of raising on the modulo.
        if verbose and print_every and i % print_every == 0:
            print("%07i" % i, lfile)
        # NOTE: torch.load unpickles data; only load files from a trusted
        # source.
        loaded = torch.load(lfile, map_location=device)
        if n_rows == 0:
            # Everything seen so far is empty: restart from this chunk rather
            # than concatenating onto an empty tensor.
            chunks = [loaded]
        else:
            chunks.append(loaded)
        n_rows += len(loaded)

    if not chunks:
        print(f"[WARNING] no files matched: {pattern}", file=sys.stderr)
        return []
    if len(chunks) == 1:
        return chunks[0]
    # One concatenation at the end instead of re-copying the accumulated
    # tensor for every file.
    return torch.cat(chunks)


def main():
    """
    Command-line entry point: parse the arguments and combine the vectors.

    The four values returned by ``read_command_combinevecs`` are, in order,
    the input scenario, the RNN passes, the output scenario and the input
    file path. They are forwarded to ``combine_vecs``; every other option
    of ``combine_vecs`` keeps its default value.
    """
    # --- read args from the command line
    cli_values = read_command_combinevecs()
    scenario_in, passes, scenario_out, config_path = cli_values

    # --- run the combination step
    combine_vecs(
        rnn_passes=passes,
        input_scenario=scenario_in,
        output_scenario=scenario_out,
        input_file_path=config_path,
    )


# Script entry point. Because this module uses relative imports
# (``from .utils import ...``), it has to be executed with package context
# (e.g. ``python -m <package>.<module>``) rather than as a bare file.
if __name__ == "__main__":
    main()
